{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Application of FL task\n",
    "from MLModel import *\n",
    "from FLModel import *\n",
    "from utils import *\n",
    "\n",
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "#device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_mnist(num_users):\n",
    "    train = datasets.MNIST(root=\"~/data/\", train=True, download=True, transform=transforms.ToTensor())\n",
    "    train_data = train.data.float()\n",
    "    train_label = train.targets\n",
    "\n",
    "    mean = train_data.mean()\n",
    "    std = train_data.std()\n",
    "    train_data = (train_data - mean) / std\n",
    "\n",
    "    test = datasets.MNIST(root=\"~/data/\", train=False, download=True, transform=transforms.ToTensor())\n",
    "    test_data = test.data.float()\n",
    "    test_label = test.targets\n",
    "    test_data = (test_data - mean) / std\n",
    "\n",
    "    # split MNIST (training set) into non-iid data sets\n",
    "    non_iid = []\n",
    "    user_dict = mnist_noniid(train_label, num_users)\n",
    "    for i in range(num_users):\n",
    "        idx = user_dict[i]\n",
    "        d = train_data[idx].flatten(1)\n",
    "        targets = train_label[idx].float()\n",
    "        non_iid.append((d, targets))\n",
    "    non_iid.append((test_data.flatten(1).float(), test_label.float()))\n",
    "    return non_iid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "1. load_data\n",
    "2. generate clients (step 3)\n",
    "3. generate aggregator\n",
    "4. training\n",
    "\"\"\"\n",
    "client_num = 5\n",
    "d = load_mnist(client_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DP-SGD with sampling rate = 5% and noise_multiplier = 6.427124318662883 iterated over 40000 steps satisfies differential privacy with eps = 8 and delta = 1e-05.\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "FL model parameters.\n",
    "\"\"\"\n",
    "lr = 0.1\n",
    "fl_param = {\n",
    "    'output_size': 10,\n",
    "    'client_num': client_num,\n",
    "    'model': MLP,\n",
    "    'data': d,\n",
    "    'lr': lr,\n",
    "    'E': 100,\n",
    "    'C': 1,\n",
    "    'eps': 8.0,\n",
    "    'delta': 1e-5,\n",
    "    'q': 0.05,\n",
    "    'clip': 4,\n",
    "    'tot_T': 20,\n",
    "    'batch_size': 64,\n",
    "    'device': device\n",
    "}\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "fl_entity = FLServer(fl_param).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "global epochs = 1, acc = 0.6658\n",
      "global epochs = 2, acc = 0.7678\n",
      "global epochs = 3, acc = 0.8175\n",
      "global epochs = 4, acc = 0.8521\n",
      "global epochs = 5, acc = 0.8594\n",
      "global epochs = 6, acc = 0.8587\n",
      "global epochs = 7, acc = 0.8657\n",
      "global epochs = 8, acc = 0.8746\n",
      "global epochs = 9, acc = 0.8760\n",
      "global epochs = 10, acc = 0.8813\n",
      "global epochs = 11, acc = 0.8833\n",
      "global epochs = 12, acc = 0.8887\n",
      "global epochs = 13, acc = 0.8915\n",
      "global epochs = 14, acc = 0.8926\n",
      "global epochs = 15, acc = 0.8956\n",
      "global epochs = 16, acc = 0.8959\n",
      "global epochs = 17, acc = 0.8965\n",
      "global epochs = 18, acc = 0.8967\n",
      "global epochs = 19, acc = 0.8974\n",
      "global epochs = 20, acc = 0.8991\n"
     ]
    }
   ],
   "source": [
    "acc = []\n",
    "for t in range(fl_param['tot_T']):\n",
    "    fl_entity.set_lr(lr/np.sqrt(t+1))\n",
    "    acc += [fl_entity.global_update()]\n",
    "    print(\"global epochs = {:d}, acc = {:.4f}\".format(t+1, acc[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
